import asyncio
import datetime
import logging
from datetime import timedelta
from typing import Any, Dict, Optional, cast

import requests
from paramiko.pkey import PKey
from paramiko.ssh_exception import PasswordRequiredException
from pydantic import ValidationError
from sqlalchemy import and_, delete, func, not_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from dstack._internal import settings
from dstack._internal.core.backends.base.compute import (
    ComputeWithCreateInstanceSupport,
    ComputeWithPlacementGroupSupport,
    GoArchType,
    get_dstack_runner_binary_path,
    get_dstack_runner_download_url,
    get_dstack_runner_version,
    get_dstack_shim_binary_path,
    get_dstack_working_dir,
    get_shim_env,
    get_shim_pre_start_commands,
)
from dstack._internal.core.backends.features import (
    BACKENDS_WITH_CREATE_INSTANCE_SUPPORT,
    BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
)
from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT

# FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute
from dstack._internal.core.errors import (
    BackendError,
    NotYetTerminated,
    ProvisioningError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.fleets import InstanceGroupPlacement
from dstack._internal.core.models.instances import (
    InstanceAvailability,
    InstanceOfferWithAvailability,
    InstanceRuntime,
    InstanceStatus,
    RemoteConnectionInfo,
    SSHKey,
)
from dstack._internal.core.models.profiles import (
    TerminationPolicy,
)
from dstack._internal.core.models.runs import (
    JobProvisioningData,
)
from dstack._internal.server import settings as server_settings
from dstack._internal.server.background.tasks.common import get_provisioning_timeout
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
    FleetModel,
    InstanceHealthCheckModel,
    InstanceModel,
    JobModel,
    ProjectModel,
)
from dstack._internal.server.schemas.instances import InstanceCheck
from dstack._internal.server.schemas.runner import (
    ComponentStatus,
    HealthcheckResponse,
    InstanceHealthResponse,
)
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services.fleets import (
    fleet_model_to_fleet,
    get_create_instance_offers,
    is_cloud_cluster,
    is_fleet_master_instance,
)
from dstack._internal.server.services.instances import (
    get_instance_configuration,
    get_instance_profile,
    get_instance_provisioning_data,
    get_instance_requirements,
    get_instance_ssh_private_keys,
    remove_dangling_tasks_from_instance,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.logging import fmt
from dstack._internal.server.services.offers import (
    get_instance_offer_with_restricted_az,
    is_divisible_into_blocks,
)
from dstack._internal.server.services.placement import (
    find_or_create_suitable_placement_group,
    get_fleet_placement_group_models,
    get_placement_group_model_for_instance,
    placement_group_model_to_placement_group_optional,
    schedule_fleet_placement_groups_deletion,
)
from dstack._internal.server.services.runner import client as runner_client
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
from dstack._internal.server.utils import sentry_utils
from dstack._internal.server.utils.provisioning import (
    detect_cpu_arch,
    get_host_info,
    get_paramiko_connection,
    get_shim_healthcheck,
    host_info_to_instance_type,
    remove_dstack_runner_if_exists,
    remove_host_info_if_exists,
    run_pre_start_commands,
    run_shim_as_systemd_service,
    upload_envs,
)
from dstack._internal.utils.common import (
    get_current_datetime,
    get_or_error,
    run_async,
)
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses
from dstack._internal.utils.ssh import (
    pkey_from_str,
)
from dstack._internal.utils.version import parse_version

MIN_PROCESSING_INTERVAL = timedelta(seconds=10)

PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60)

TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20)
TERMINATION_RETRY_TIMEOUT = timedelta(seconds=30)
TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15)
PROVISIONING_TIMEOUT_SECONDS = 10 * 60  # 10 minutes in seconds


logger = get_logger(__name__)


async def process_instances(batch_size: int = 1):
    """Process up to `batch_size` instances concurrently.

    Each task independently picks and processes one pending instance.
    """
    await asyncio.gather(*(_process_next_instance() for _ in range(batch_size)))


@sentry_utils.instrument_background_task
async def delete_instance_health_checks():
    """Delete instance health check records older than the configured TTL."""
    ttl = timedelta(seconds=server_settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS)
    cutoff = get_current_datetime() - ttl
    async with get_session_ctx() as session:
        stmt = delete(InstanceHealthCheckModel).where(
            InstanceHealthCheckModel.collected_at < cutoff
        )
        await session.execute(stmt)
        await session.commit()


@sentry_utils.instrument_background_task
async def _process_next_instance():
    """Pick one instance due for processing, lock it, and process it.

    Two locking layers are combined:
    - an in-memory lockset guards against concurrent processing of the same
      instance within this server process;
    - ``SELECT ... FOR UPDATE SKIP LOCKED`` guards against concurrent
      processing across server replicas.
    """
    lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__)
    async with get_session_ctx() as session:
        async with lock:
            res = await session.execute(
                select(InstanceModel)
                .where(
                    InstanceModel.status.in_(
                        [
                            InstanceStatus.PENDING,
                            InstanceStatus.PROVISIONING,
                            InstanceStatus.BUSY,
                            InstanceStatus.IDLE,
                            InstanceStatus.TERMINATING,
                        ]
                    ),
                    # Terminating instances belonging to a compute group
                    # are handled by process_compute_groups.
                    not_(
                        and_(
                            InstanceModel.status == InstanceStatus.TERMINATING,
                            InstanceModel.compute_group_id.is_not(None),
                        )
                    ),
                    InstanceModel.id.not_in(lockset),
                    # Rate-limit: skip instances processed recently.
                    InstanceModel.last_processed_at
                    < get_current_datetime() - MIN_PROCESSING_INTERVAL,
                )
                .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
                .options(joinedload(InstanceModel.project).load_only(ProjectModel.ssh_private_key))
                # Process the least recently processed instance first.
                .order_by(InstanceModel.last_processed_at.asc())
                .limit(1)
                .with_for_update(skip_locked=True, key_share=True, of=InstanceModel)
            )
            instance = res.scalar()
            if instance is None:
                return
            # Register the id in the lockset before releasing the in-memory
            # lock so other tasks in this process skip this instance.
            lockset.add(instance.id)
        instance_model_id = instance.id
        try:
            await _process_instance(session=session, instance=instance)
        finally:
            # Always release the in-memory lock, even if processing raised.
            lockset.difference_update([instance_model_id])


async def _process_instance(session: AsyncSession, instance: InstanceModel):
    """Dispatch an instance to the handler matching its current status."""
    # Refetch to load related attributes.
    # Load related attributes only for statuses that always need them.
    # PENDING/TERMINATING handlers additionally need the project's backends;
    # the IDLE handler only needs the project itself.
    if instance.status in (
        InstanceStatus.PENDING,
        InstanceStatus.TERMINATING,
        InstanceStatus.IDLE,
    ):
        project_load = joinedload(InstanceModel.project)
        if instance.status != InstanceStatus.IDLE:
            project_load = project_load.joinedload(ProjectModel.backends)
        res = await session.execute(
            select(InstanceModel)
            .where(InstanceModel.id == instance.id)
            .options(project_load)
            .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
            .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
            .execution_options(populate_existing=True)
        )
        instance = res.unique().scalar_one()

    if instance.status == InstanceStatus.PENDING:
        if instance.remote_connection_info is None:
            await _create_instance(
                session=session,
                instance=instance,
            )
        else:
            await _add_remote(instance)
    elif instance.status == InstanceStatus.TERMINATING:
        await _terminate(instance)
    elif instance.status in (
        InstanceStatus.PROVISIONING,
        InstanceStatus.IDLE,
        InstanceStatus.BUSY,
    ):
        if not _check_and_mark_terminating_if_idle_duration_expired(instance):
            await _check_instance(session, instance)

    instance.last_processed_at = get_current_datetime()
    await session.commit()


def _check_and_mark_terminating_if_idle_duration_expired(instance: InstanceModel) -> bool:
    """Mark an idle instance as TERMINATING if its idle duration expired.

    Applies only to IDLE instances with the DESTROY_AFTER_IDLE termination
    policy and no attached jobs. Returns True if the instance was marked
    for termination, False otherwise.
    """
    if not (
        instance.status == InstanceStatus.IDLE
        and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
        and not instance.jobs
    ):
        return False
    if instance.fleet is not None and not _can_terminate_fleet_instances_on_idle_duration(
        instance.fleet
    ):
        logger.debug(
            "Skipping instance %s termination on idle duration. Fleet is already at `nodes.min`.",
            instance.name,
        )
        return False
    idle_duration = _get_instance_idle_duration(instance)
    idle_seconds = instance.termination_idle_time
    # Use the module-level `timedelta` import for consistency with the file.
    delta = timedelta(seconds=idle_seconds)
    if idle_duration > delta:
        instance.status = InstanceStatus.TERMINATING
        instance.termination_reason = "Idle timeout"
        logger.info(
            "Instance %s idle duration expired: idle time %ss. Terminating",
            instance.name,
            # total_seconds() instead of .seconds: `.seconds` wraps at 24 hours
            # and would under-report idle time for durations longer than a day.
            int(idle_duration.total_seconds()),
            extra={
                "instance_name": instance.name,
                "instance_status": instance.status.value,
            },
        )
        return True
    return False


def _can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> bool:
    # Do not terminate instances on idle duration if fleet is already at `nodes.min`.
    # This is an optimization to avoid terminate-create loop.
    # There may be race conditions since we don't take the fleet lock.
    # That's ok: in the worst case we go below `nodes.min`, but
    # the fleet consolidation logic will provision new nodes.
    fleet = fleet_model_to_fleet(fleet_model)
    nodes = fleet.spec.configuration.nodes
    if nodes is None or fleet.spec.autocreated:
        return True
    active_count = sum(1 for i in fleet_model.instances if i.status.is_active())
    return active_count > nodes.min


async def _add_remote(instance: InstanceModel) -> None:
    """Provision an ssh (BackendType.REMOTE) instance from its connection info.

    Deploys dstack-shim to the remote host and, on success, fills in the
    instance's provisioning data and offer. Unrecoverable errors (bad SSH
    key, network/IP mismatch, non-divisible blocks) mark the instance
    TERMINATED; transient provisioning errors return it to PENDING to be
    retried until the provisioning timeout expires.
    """
    logger.info("Adding ssh instance %s...", instance.name)
    if instance.status == InstanceStatus.PENDING:
        instance.status = InstanceStatus.PROVISIONING

    # Give up if the instance could not be added within the overall timeout
    # (measured from instance creation, across retries).
    retry_duration_deadline = instance.created_at + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
    if retry_duration_deadline < get_current_datetime():
        instance.status = InstanceStatus.TERMINATED
        instance.termination_reason = "Provisioning timeout expired"
        logger.warning(
            "Failed to start instance %s in %d seconds. Terminating...",
            instance.name,
            PROVISIONING_TIMEOUT_SECONDS,
            extra={
                "instance_name": instance.name,
                "instance_status": InstanceStatus.TERMINATED.value,
            },
        )
        return

    try:
        remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info))
        # Prepare connection key
        try:
            pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys)
            if remote_details.ssh_proxy_keys is not None:
                ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys)
            else:
                ssh_proxy_pkeys = None
        except (ValueError, PasswordRequiredException):
            instance.status = InstanceStatus.TERMINATED
            instance.termination_reason = "Unsupported private SSH key type"
            logger.warning(
                "Failed to add instance %s: unsupported private SSH key type",
                instance.name,
                extra={
                    "instance_name": instance.name,
                    "instance_status": InstanceStatus.TERMINATED.value,
                },
            )
            return

        # Authorize both the fleet's keys and the project key on the host.
        authorized_keys = [pk.public.strip() for pk in remote_details.ssh_keys]
        authorized_keys.append(instance.project.ssh_public_key.strip())

        try:
            # _deploy_instance is blocking (paramiko), so run it in a thread.
            future = run_async(
                _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys
            )
            deploy_timeout = 20 * 60  # 20 minutes
            result = await asyncio.wait_for(future, timeout=deploy_timeout)
            health, host_info, arch = result
        except (asyncio.TimeoutError, TimeoutError) as e:
            raise ProvisioningError(f"Deploy timeout: {e}") from e
        except Exception as e:
            raise ProvisioningError(f"Deploy instance raised an error: {e}") from e
        else:
            logger.info(
                "The instance %s (%s) was successfully added",
                instance.name,
                remote_details.host,
            )
    except ProvisioningError as e:
        # Transient failure: back to PENDING so the deploy is retried until
        # the provisioning timeout deadline above.
        logger.warning(
            "Provisioning instance %s could not be completed because of the error: %s",
            instance.name,
            e,
        )
        instance.status = InstanceStatus.PENDING
        return

    instance_type = host_info_to_instance_type(host_info, arch)
    instance_network = None
    internal_ip = None
    # Carry over network settings from previously stored provisioning data,
    # if any (e.g. from a prior attempt).
    try:
        default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data)
        instance_network = default_jpd.instance_network
        internal_ip = default_jpd.internal_ip
    except ValidationError:
        pass

    host_network_addresses = host_info.get("addresses", [])
    if internal_ip is None:
        internal_ip = get_ip_from_network(
            network=instance_network,
            addresses=host_network_addresses,
        )
    # A network was specified but no host address belongs to it.
    if instance_network is not None and internal_ip is None:
        instance.status = InstanceStatus.TERMINATED
        instance.termination_reason = "Failed to locate internal IP address on the given network"
        logger.warning(
            "Failed to add instance %s: failed to locate internal IP address on the given network",
            instance.name,
            extra={
                "instance_name": instance.name,
                "instance_status": InstanceStatus.TERMINATED.value,
            },
        )
        return
    # An explicitly specified internal IP must match one of the host's
    # actual interface addresses.
    if internal_ip is not None:
        if not is_ip_among_addresses(ip_address=internal_ip, addresses=host_network_addresses):
            instance.status = InstanceStatus.TERMINATED
            instance.termination_reason = (
                "Specified internal IP not found among instance interfaces"
            )
            logger.warning(
                "Failed to add instance %s: specified internal IP not found among instance interfaces",
                instance.name,
                extra={
                    "instance_name": instance.name,
                    "instance_status": InstanceStatus.TERMINATED.value,
                },
            )
            return

    # Validate (or auto-compute) the number of blocks the instance's
    # resources can be split into.
    divisible, blocks = is_divisible_into_blocks(
        cpu_count=instance_type.resources.cpus,
        gpu_count=len(instance_type.resources.gpus),
        blocks="auto" if instance.total_blocks is None else instance.total_blocks,
    )
    if divisible:
        instance.total_blocks = blocks
    else:
        instance.status = InstanceStatus.TERMINATED
        instance.termination_reason = "Cannot split into blocks"
        logger.warning(
            "Failed to add instance %s: cannot split into blocks",
            instance.name,
            extra={
                "instance_name": instance.name,
                "instance_status": InstanceStatus.TERMINATED.value,
            },
        )
        return

    region = instance.region
    assert region is not None  # always set for ssh instances
    jpd = JobProvisioningData(
        backend=BackendType.REMOTE,
        instance_type=instance_type,
        instance_id="instance_id",
        hostname=remote_details.host,
        region=region,
        price=0,
        internal_ip=internal_ip,
        instance_network=instance_network,
        username=remote_details.ssh_user,
        ssh_port=remote_details.port,
        dockerized=True,
        backend_data=None,
        ssh_proxy=remote_details.ssh_proxy,
    )

    # If the shim healthcheck passed, the instance is immediately usable;
    # otherwise keep it PROVISIONING until _check_instance sees it healthy.
    instance.status = InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING
    instance.backend = BackendType.REMOTE
    instance_offer = InstanceOfferWithAvailability(
        backend=BackendType.REMOTE,
        instance=instance_type,
        region=region,
        price=0,
        availability=InstanceAvailability.AVAILABLE,
        instance_runtime=InstanceRuntime.SHIM,
    )
    instance.price = 0
    instance.offer = instance_offer.json()
    instance.job_provisioning_data = jpd.json()
    instance.started_at = get_current_datetime()


def _deploy_instance(
    remote_details: RemoteConnectionInfo,
    pkeys: list[PKey],
    ssh_proxy_pkeys: Optional[list[PKey]],
    authorized_keys: list[str],
) -> tuple[InstanceCheck, dict[str, Any], GoArchType]:
    """Deploy dstack-shim to a remote host over SSH (blocking).

    Runs the full provisioning sequence: pre-start commands, environment
    upload, shim installation as a systemd service, then collects host info
    and the shim healthcheck result.

    Returns a tuple of (instance check, raw host info dict, CPU architecture).
    Raises ProvisioningError on invalid fleet env or unparseable healthcheck.
    """
    with get_paramiko_connection(
        remote_details.ssh_user,
        remote_details.host,
        remote_details.port,
        pkeys,
        remote_details.ssh_proxy,
        ssh_proxy_pkeys,
    ) as client:
        # Lazy %-style args for consistency with the rest of the module
        # (avoids eager string formatting when the log level is disabled).
        logger.info("Connected to %s %s", remote_details.ssh_user, remote_details.host)

        arch = detect_cpu_arch(client)
        logger.info("%s: CPU arch is %s", remote_details.host, arch)

        # Execute pre start commands
        shim_pre_start_commands = get_shim_pre_start_commands(arch=arch)
        run_pre_start_commands(client, shim_pre_start_commands, authorized_keys)
        logger.debug("The script for installing dstack has been executed")

        # Upload envs
        shim_envs = get_shim_env(authorized_keys, arch=arch)
        try:
            fleet_configuration_envs = remote_details.env.as_dict()
        except ValueError as e:
            raise ProvisioningError(f"Invalid Env: {e}") from e
        shim_envs.update(fleet_configuration_envs)
        dstack_working_dir = get_dstack_working_dir()
        dstack_shim_binary_path = get_dstack_shim_binary_path()
        dstack_runner_binary_path = get_dstack_runner_binary_path()
        upload_envs(client, dstack_working_dir, shim_envs)
        logger.debug("The dstack-shim environment variables have been installed")

        # Ensure we have fresh versions of host info.json and dstack-runner
        remove_host_info_if_exists(client, dstack_working_dir)
        remove_dstack_runner_if_exists(client, dstack_runner_binary_path)

        # Run dstack-shim as a systemd service
        run_shim_as_systemd_service(
            client=client,
            binary_path=dstack_shim_binary_path,
            working_dir=dstack_working_dir,
            dev=settings.DSTACK_VERSION is None,
        )

        # Get host info
        host_info = get_host_info(client, dstack_working_dir)
        logger.debug("Received a host_info %s", host_info)

        healthcheck_out = get_shim_healthcheck(client)
        try:
            healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out)
        except ValueError as e:
            raise ProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e
        instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck)

        return instance_check, host_info, arch


async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
    """Try to provision a cloud instance for a PENDING instance model.

    Iterates over matching offers and attempts create_instance with each
    backend until one succeeds. On success the instance moves to
    PROVISIONING; if all offers fail (or none are found), the instance —
    and, for cloud clusters, its sibling instances — is terminated.
    """
    # Non-master instances in a cluster fleet must wait until the master
    # determines the backend/region/placement group.
    if _need_to_wait_fleet_provisioning(instance):
        logger.debug("Waiting for the first instance in the fleet to be provisioned")
        return

    try:
        instance_configuration = get_instance_configuration(instance)
        profile = get_instance_profile(instance)
        requirements = get_instance_requirements(instance)
    except ValidationError as e:
        instance.status = InstanceStatus.TERMINATED
        instance.termination_reason = (
            f"Error to parse profile, requirements or instance_configuration: {e}"
        )
        logger.warning(
            "Error to parse profile, requirements or instance_configuration. Terminate instance: %s",
            instance.name,
            extra={
                "instance_name": instance.name,
                "instance_status": InstanceStatus.TERMINATED.value,
            },
        )
        return

    # The placement group is determined when provisioning the master instance
    # and used for all other instances in the fleet.
    placement_group_models = await get_fleet_placement_group_models(
        session=session,
        fleet_id=instance.fleet_id,
    )
    placement_group_model = get_placement_group_model_for_instance(
        placement_group_models=placement_group_models,
        instance_model=instance,
    )
    offers = await get_create_instance_offers(
        project=instance.project,
        profile=profile,
        requirements=requirements,
        fleet_model=instance.fleet,
        placement_group=placement_group_model_to_placement_group_optional(placement_group_model),
        blocks="auto" if instance.total_blocks is None else instance.total_blocks,
        exclude_not_available=True,
    )

    # Limit number of offers tried to prevent long-running processing
    # in case all offers fail.
    for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]:
        if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT:
            continue
        compute = backend.compute()
        assert isinstance(compute, ComputeWithCreateInstanceSupport)
        instance_offer = _get_instance_offer_for_instance(instance_offer, instance)
        # For the master instance of a cloud cluster, find or create a
        # placement group compatible with this offer before provisioning.
        if (
            instance.fleet
            and is_cloud_cluster(instance.fleet)
            and is_fleet_master_instance(instance)
            and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
            and isinstance(compute, ComputeWithPlacementGroupSupport)
            and (
                compute.are_placement_groups_compatible_with_reservations(instance_offer.backend)
                or instance_configuration.reservation is None
            )
        ):
            placement_group_model = await find_or_create_suitable_placement_group(
                fleet_model=instance.fleet,
                placement_groups=placement_group_models,
                instance_offer=instance_offer,
                compute=compute,
            )
            if placement_group_model is None:  # error occurred
                continue
            session.add(placement_group_model)
            placement_group_models.append(placement_group_model)
        logger.debug(
            "Trying %s in %s/%s for $%0.4f per hour",
            instance_offer.instance.name,
            instance_offer.backend.value,
            instance_offer.region,
            instance_offer.price,
        )
        try:
            # create_instance is blocking; run it in a worker thread.
            job_provisioning_data = await run_async(
                compute.create_instance,
                instance_offer,
                instance_configuration,
                placement_group_model_to_placement_group_optional(placement_group_model),
            )
        except BackendError as e:
            # Expected backend-side failure: try the next offer.
            logger.warning(
                "%s launch in %s/%s failed: %s",
                instance_offer.instance.name,
                instance_offer.backend.value,
                instance_offer.region,
                repr(e),
                extra={"instance_name": instance.name},
            )
            continue
        except Exception:
            # Unexpected failure: log with traceback, still try the next offer.
            logger.exception(
                "Got exception when launching %s in %s/%s",
                instance_offer.instance.name,
                instance_offer.backend.value,
                instance_offer.region,
            )
            continue

        instance.status = InstanceStatus.PROVISIONING
        instance.backend = backend.TYPE
        instance.region = instance_offer.region
        instance.price = instance_offer.price
        instance.instance_configuration = instance_configuration.json()
        instance.job_provisioning_data = job_provisioning_data.json()
        instance.offer = instance_offer.json()
        instance.total_blocks = instance_offer.total_blocks
        instance.started_at = get_current_datetime()

        logger.info(
            "Created instance %s",
            instance.name,
            extra={
                "instance_name": instance.name,
                "instance_status": InstanceStatus.PROVISIONING.value,
            },
        )
        if instance.fleet_id and is_fleet_master_instance(instance):
            # Clean up placement groups that did not end up being used.
            # Flush to update still uncommitted placement groups.
            await session.flush()
            await schedule_fleet_placement_groups_deletion(
                session=session,
                fleet_id=instance.fleet_id,
                except_placement_group_ids=(
                    [placement_group_model.id] if placement_group_model is not None else []
                ),
            )
        return

    _mark_terminated(instance, "All offers failed" if offers else "No offers found")
    if instance.fleet and is_fleet_master_instance(instance) and is_cloud_cluster(instance.fleet):
        # Do not attempt to deploy other instances, as they won't determine the correct cluster
        # backend, region, and placement group without a successfully deployed master instance
        for sibling_instance in instance.fleet.instances:
            if sibling_instance.id == instance.id:
                continue
            _mark_terminated(sibling_instance, "Master instance failed to start")


def _mark_terminated(instance: InstanceModel, termination_reason: str) -> None:
    """Set the instance status to TERMINATED and record the reason."""
    instance.termination_reason = termination_reason
    instance.status = InstanceStatus.TERMINATED
    logger.info(
        "Terminated instance %s: %s",
        instance.name,
        termination_reason,
        extra={
            "instance_name": instance.name,
            "instance_status": InstanceStatus.TERMINATED.value,
        },
    )


async def _check_instance(session: AsyncSession, instance: InstanceModel) -> None:
    """Check a PROVISIONING/IDLE/BUSY instance and update its status/health.

    Handles: busy instances with only finished jobs (mark TERMINATING),
    instances still waiting for provisioning data, non-dockerized instances,
    and dockerized instances reached via the shim over an SSH tunnel —
    recording health checks and enforcing provisioning/termination deadlines.
    """
    if (
        instance.status == InstanceStatus.BUSY
        and instance.jobs
        and all(job.status.is_finished() for job in instance.jobs)
    ):
        # A busy instance could have no active jobs due to this bug: https://github.com/dstackai/dstack/issues/2068
        instance.status = InstanceStatus.TERMINATING
        instance.termination_reason = "Instance job finished"
        logger.info(
            "Detected busy instance %s with finished job. Marked as TERMINATING",
            instance.name,
            extra={
                "instance_name": instance.name,
                "instance_status": instance.status.value,
            },
        )
        return

    job_provisioning_data = get_or_error(get_instance_provisioning_data(instance))
    # No hostname yet: the backend hasn't reported the instance as running.
    if job_provisioning_data.hostname is None:
        res = await session.execute(
            select(ProjectModel)
            .where(ProjectModel.id == instance.project_id)
            .options(joinedload(ProjectModel.backends))
        )
        project = res.unique().scalar_one()
        await _wait_for_instance_provisioning_data(
            project=project,
            instance=instance,
            job_provisioning_data=job_provisioning_data,
        )
        return

    # Non-dockerized instances run no shim, so there is nothing to check;
    # they are considered busy as soon as provisioned.
    if not job_provisioning_data.dockerized:
        if instance.status == InstanceStatus.PROVISIONING:
            instance.status = InstanceStatus.BUSY
        return

    ssh_private_keys = get_instance_ssh_private_keys(instance)

    # Only collect a new health check if none was collected within the
    # minimum collection interval.
    health_check_cutoff = get_current_datetime() - timedelta(
        seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS
    )
    res = await session.execute(
        select(func.count(1)).where(
            InstanceHealthCheckModel.instance_id == instance.id,
            InstanceHealthCheckModel.collected_at > health_check_cutoff,
        )
    )
    check_instance_health = res.scalar_one() == 0

    # May return False if fails to establish ssh connection
    instance_check = await run_async(
        _check_instance_inner,
        ssh_private_keys,
        job_provisioning_data,
        None,
        instance=instance,
        check_instance_health=check_instance_health,
    )
    if instance_check is False:
        instance_check = InstanceCheck(reachable=False, message="SSH or tunnel error")

    if instance_check.reachable and check_instance_health:
        health_status = instance_check.get_health_status()
    else:
        # Keep previous health status
        health_status = instance.health

    # Log at WARNING when something is wrong (unreachable available instance,
    # or a freshly collected unhealthy status); DEBUG otherwise.
    loglevel = logging.DEBUG
    if not instance_check.reachable and instance.status.is_available():
        loglevel = logging.WARNING
    elif check_instance_health and not health_status.is_healthy():
        loglevel = logging.WARNING
    logger.log(
        loglevel,
        "Instance %s check: reachable=%s health_status=%s message=%r",
        instance.name,
        instance_check.reachable,
        health_status.name,
        instance_check.message,
        extra={"instance_name": instance.name, "health_status": health_status},
    )

    if instance_check.has_health_checks():
        # ensured by has_health_checks()
        assert instance_check.health_response is not None
        health_check_model = InstanceHealthCheckModel(
            instance_id=instance.id,
            collected_at=get_current_datetime(),
            status=health_status,
            response=instance_check.health_response.json(),
        )
        session.add(health_check_model)

    instance.health = health_status
    instance.unreachable = not instance_check.reachable

    if instance_check.reachable:
        # Reachable again: clear any pending termination deadline.
        instance.termination_deadline = None

        if instance.status == InstanceStatus.PROVISIONING:
            instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY
            logger.info(
                "Instance %s has switched to %s status",
                instance.name,
                instance.status.value,
                extra={
                    "instance_name": instance.name,
                    "instance_status": instance.status.value,
                },
            )
        return

    # Unreachable: start (or keep) the termination-deadline countdown.
    if instance.termination_deadline is None:
        instance.termination_deadline = get_current_datetime() + TERMINATION_DEADLINE_OFFSET

    if instance.status == InstanceStatus.PROVISIONING and instance.started_at is not None:
        provisioning_deadline = _get_provisioning_deadline(
            instance=instance,
            job_provisioning_data=job_provisioning_data,
        )
        if get_current_datetime() > provisioning_deadline:
            instance.status = InstanceStatus.TERMINATING
            logger.warning(
                "Instance %s has not started in time. Marked as TERMINATING",
                instance.name,
                extra={
                    "instance_name": instance.name,
                    "instance_status": InstanceStatus.TERMINATING.value,
                },
            )
    elif instance.status.is_available():
        deadline = instance.termination_deadline
        if get_current_datetime() > deadline:
            instance.status = InstanceStatus.TERMINATING
            instance.termination_reason = "Termination deadline"
            logger.warning(
                "Instance %s shim waiting timeout. Marked as TERMINATING",
                instance.name,
                extra={
                    "instance_name": instance.name,
                    "instance_status": InstanceStatus.TERMINATING.value,
                },
            )

async def _wait_for_instance_provisioning_data(
    project: ProjectModel,
    instance: InstanceModel,
    job_provisioning_data: JobProvisioningData,
):
    """Wait for a freshly launched instance to reach the running state.

    Marks the instance TERMINATING when the provisioning deadline has passed
    or when the backend is no longer available; otherwise asks the backend to
    refresh the provisioning data and persists it on the model.
    """
    logger.debug(
        "Waiting for instance %s to become running",
        instance.name,
    )
    deadline = _get_provisioning_deadline(
        instance=instance,
        job_provisioning_data=job_provisioning_data,
    )
    if get_current_datetime() > deadline:
        logger.warning(
            "Instance %s failed because instance has not become running in time", instance.name
        )
        instance.status = InstanceStatus.TERMINATING
        instance.termination_reason = "Instance has not become running in time"
        return

    backend = await backends_services.get_project_backend_by_type(
        project=project,
        backend_type=job_provisioning_data.backend,
    )
    if backend is None:
        logger.warning(
            "Instance %s failed because instance's backend is not available",
            instance.name,
        )
        instance.status = InstanceStatus.TERMINATING
        instance.termination_reason = "Backend not available"
        return

    try:
        # Let the backend fill in hostname/IP details once the instance runs.
        await run_async(
            backend.compute().update_provisioning_data,
            job_provisioning_data,
            project.ssh_public_key,
            project.ssh_private_key,
        )
        instance.job_provisioning_data = job_provisioning_data.json()
    except ProvisioningError as e:
        logger.warning(
            "Error while waiting for instance %s to become running: %s",
            instance.name,
            repr(e),
        )
        instance.status = InstanceStatus.TERMINATING
        instance.termination_reason = "Error while waiting for instance to become running"
    except Exception:
        # Unexpected failures are logged but not fatal; we will retry on the
        # next processing iteration.
        logger.exception(
            "Got exception when updating instance %s provisioning data", instance.name
        )

@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
def _check_instance_inner(
    ports: Dict[int, int], *, instance: InstanceModel, check_instance_health: bool = False
) -> InstanceCheck:
    """Probe the shim through the SSH tunnel and report reachability.

    While the connection is open, also opportunistically updates the runner
    binary and removes dangling tasks from the instance.
    """
    health_response: Optional[InstanceHealthResponse] = None
    shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
    # Track the last shim call made so error messages name the failing endpoint.
    current_call = shim_client.healthcheck
    try:
        healthcheck_response = current_call(unmask_exceptions=True)
        if check_instance_health:
            current_call = shim_client.get_instance_health
            health_response = current_call()
    except requests.RequestException as e:
        err_fmt = "shim.%s(): request error: %s"
        err_args = (current_call.__func__.__name__, e)
        logger.debug(err_fmt, *err_args)
        return InstanceCheck(reachable=False, message=err_fmt % err_args)
    except Exception as e:
        err_fmt = "shim.%s(): unexpected exception %s: %s"
        err_args = (current_call.__func__.__name__, e.__class__.__name__, e)
        logger.exception(err_fmt, *err_args)
        return InstanceCheck(reachable=False, message=err_fmt % err_args)

    _maybe_update_runner(instance, shim_client)

    try:
        remove_dangling_tasks_from_instance(shim_client, instance)
    except Exception as e:
        # Best-effort cleanup; a failure here must not mark the instance unreachable.
        logger.exception("%s: error removing dangling tasks: %s", fmt(instance), e)

    return runner_client.healthcheck_response_to_instance_check(
        healthcheck_response, health_response
    )


def _maybe_update_runner(instance: InstanceModel, shim_client: runner_client.ShimClient) -> None:
    """Install or upgrade the runner binary on the instance when it is outdated.

    To auto-update to the latest runner dev build from the CI,
    see DSTACK_USE_LATEST_FROM_BRANCH.
    """
    target_version_str = get_dstack_runner_version()
    try:
        target_version = parse_version(target_version_str)
    except ValueError as e:
        logger.warning("Failed to parse expected runner version: %s", e)
        return
    if target_version is None:
        logger.debug("Cannot determine the expected runner version")
        return

    try:
        runner_info = shim_client.get_runner_info()
    except requests.RequestException as e:
        logger.warning("Instance %s: shim.get_runner_info(): request error: %s", instance.name, e)
        return
    if runner_info is None:
        logger.debug("Instance %s: no runner info", instance.name)
        return

    logger.debug(
        "Instance %s: runner status=%s version=%s",
        instance.name,
        runner_info.status.value,
        runner_info.version,
    )
    # An installation is already in progress — don't interfere.
    if runner_info.status == ComponentStatus.INSTALLING:
        return

    if not runner_info.version:
        # No version reported at all — treat as a fresh install.
        logger.debug("Instance %s: installing runner %s", instance.name, target_version)
    else:
        try:
            installed_version = parse_version(runner_info.version)
        except ValueError as e:
            logger.warning("Instance %s: failed to parse runner version: %s", instance.name, e)
            return

        if installed_version is None or installed_version >= target_version:
            logger.debug("Instance %s: the latest runner version already installed", instance.name)
            return

        logger.debug(
            "Instance %s: updating runner %s -> %s",
            instance.name,
            installed_version,
            target_version,
        )

    jpd = get_or_error(get_instance_provisioning_data(instance))
    download_url = get_dstack_runner_download_url(
        arch=jpd.instance_type.resources.cpu_arch, version=target_version_str
    )
    try:
        shim_client.install_runner(download_url)
    except requests.RequestException as e:
        logger.warning("Instance %s: shim.install_runner(): %s", instance.name, e)


async def _terminate(instance: InstanceModel) -> None:
    """Terminate the backend instance (with retries) and mark the model deleted.

    Returns early without side effects when the retry backoff has not elapsed
    yet, or when a failed termination attempt is still within the retry window.
    """
    # Respect the backoff between termination attempts.
    if (
        instance.last_termination_retry_at is not None
        and _next_termination_retry_at(instance) > get_current_datetime()
    ):
        return
    provisioning = get_instance_provisioning_data(instance)
    # Remote (SSH) instances have no backend resource to release.
    if provisioning is not None and provisioning.backend != BackendType.REMOTE:
        backend = await backends_services.get_project_backend_by_type(
            project=instance.project, backend_type=provisioning.backend
        )
        if backend is None:
            logger.error(
                "Failed to terminate instance %s. Backend %s not available.",
                instance.name,
                provisioning.backend,
            )
        else:
            logger.debug("Terminating runner instance %s", provisioning.hostname)
            try:
                await run_async(
                    backend.compute().terminate_instance,
                    provisioning.instance_id,
                    provisioning.region,
                    provisioning.backend_data,
                )
            except Exception as e:
                if instance.first_termination_retry_at is None:
                    instance.first_termination_retry_at = get_current_datetime()
                instance.last_termination_retry_at = get_current_datetime()
                can_retry = _next_termination_retry_at(instance) < _get_termination_deadline(
                    instance
                )
                if can_retry:
                    if isinstance(e, NotYetTerminated):
                        logger.debug("Instance %s termination in progress: %s", instance.name, e)
                    else:
                        logger.warning(
                            "Failed to terminate instance %s. Will retry. Error: %r",
                            instance.name,
                            e,
                            exc_info=not isinstance(e, BackendError),
                        )
                    # Leave the instance alive so the next iteration retries.
                    return
                logger.error(
                    "Failed all attempts to terminate instance %s."
                    " Please terminate the instance manually to avoid unexpected charges."
                    " Error: %r",
                    instance.name,
                    e,
                    exc_info=not isinstance(e, BackendError),
                )

    # Either termination succeeded, was not needed, or all retries are exhausted.
    instance.deleted = True
    instance.deleted_at = get_current_datetime()
    instance.finished_at = get_current_datetime()
    instance.status = InstanceStatus.TERMINATED
    logger.info(
        "Instance %s terminated",
        instance.name,
        extra={
            "instance_name": instance.name,
            "instance_status": InstanceStatus.TERMINATED.value,
        },
    )


def _next_termination_retry_at(instance: InstanceModel) -> datetime.datetime:
    """Return the earliest moment the next termination attempt may run."""
    last_attempt = instance.last_termination_retry_at
    assert last_attempt is not None
    return last_attempt + TERMINATION_RETRY_TIMEOUT


def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime:
    """Return the point after which termination retries are abandoned."""
    first_attempt = instance.first_termination_retry_at
    assert first_attempt is not None
    return first_attempt + TERMINATION_RETRY_MAX_DURATION


def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
    """Tell whether this instance must wait for the fleet master instance.

    Cluster cloud instances wait for the first fleet instance to be
    provisioned so that they end up in the same backend/region.
    """
    fleet = instance.fleet
    if fleet is None:
        return False
    if is_fleet_master_instance(instance):
        return False
    master = fleet.instances[0]
    if master.job_provisioning_data is not None:
        # Master already provisioned — no need to wait.
        return False
    if master.status == InstanceStatus.TERMINATED:
        # Master will never provision — don't block forever.
        return False
    return is_cloud_cluster(fleet)


def _get_instance_offer_for_instance(
    instance_offer: InstanceOfferWithAvailability,
    instance: InstanceModel,
) -> InstanceOfferWithAvailability:
    """Adjust an offer for fleet constraints.

    For cluster-placement fleets, restricts the offer's availability zone to
    the fleet master's zone; otherwise returns the offer unchanged.
    """
    if instance.fleet is None:
        return instance_offer
    fleet = fleet_model_to_fleet(instance.fleet)
    master = instance.fleet.instances[0]
    master_jpd = get_instance_provisioning_data(master)
    if fleet.spec.configuration.placement != InstanceGroupPlacement.CLUSTER:
        return instance_offer
    return get_instance_offer_with_restricted_az(
        instance_offer=instance_offer,
        master_job_provisioning_data=master_jpd,
    )


def _get_instance_idle_duration(instance: InstanceModel) -> datetime.timedelta:
    """Return time elapsed since the last processed job (or since creation)."""
    reference = instance.last_job_processed_at
    if reference is None:
        reference = instance.created_at
    return get_current_datetime() - reference


def _get_provisioning_deadline(
    instance: InstanceModel,
    job_provisioning_data: JobProvisioningData,
) -> datetime.datetime:
    """Return the latest time by which the instance must have become running.

    The timeout depends on the backend and instance type.
    """
    assert instance.started_at is not None
    timeout = get_provisioning_timeout(
        backend_type=job_provisioning_data.get_base_backend(),
        instance_type_name=job_provisioning_data.instance_type.name,
    )
    return instance.started_at + timeout


def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:
    """Convert SSH keys that carry private material into paramiko PKey objects."""
    pkeys: list[PKey] = []
    for key in ssh_keys:
        # Keys without a private part cannot be used for authentication.
        if key.private is not None:
            pkeys.append(pkey_from_str(key.private))
    return pkeys